#!/usr/bin/env python3
# Standalone launcher script that sets up mounts and GPUs accordingly, and
# passes arguments through to a Docker image.
# Mirrors some arguments from the argument parser at src/AIMD/arguments.py
import argparse
import os
import subprocess
import sys
from os import path as osp

# Docker image used for `docker run`; refreshed by --software-update.
_docker_image = "ghcr.io/microsoft/ai2bmd:latest"
# Images pulled by --download-training-data. Each one carries one part of the
# training-data archive (/ai2bmd-training-part-N.tar.xz).
_docker_image_data1 = "ghcr.io/microsoft/ai2bmd:training-data1"
_docker_image_data2 = "ghcr.io/microsoft/ai2bmd:training-data2"
_docker_image_data3 = "ghcr.io/microsoft/ai2bmd:training-data3"
_docker_image_data4 = "ghcr.io/microsoft/ai2bmd:training-data4"
_docker_image_data5 = "ghcr.io/microsoft/ai2bmd:training-data5"
_docker_image_data6 = "ghcr.io/microsoft/ai2bmd:training-data6"
# Base path used to build the default --ckpt-path; the program is launched as
# /ai2bmd/main.py inside the container.
_src_dir = "/ai2bmd"

def parse_args(argv):
    """Parse the launcher's command line.

    The first group of options mirrors the simulation program's own options
    (see the header comment of this file); the second group is specific to
    this launcher.

    Each ``--x`` / ``--no-x`` pair writes to the same destination. When
    neither flag is given the value is ``False``, because the ``store_true``
    action is registered first and its default is applied first.

    Args:
        argv: List of argument strings, without the program name.

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    # original arguments (mirrored from src/AIMD/arguments.py)
    parser = argparse.ArgumentParser(description="DL Molecular Simulation.")
    parser.add_argument(
        "--base-dir",
        type=str,
        default=os.getcwd(),
        help="A directory for running simulation",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default=None,
        help="A directory for saving results",
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default=osp.join(_src_dir, "ViSNet/checkpoints"),
        help="A directory including well-trained pytorch models",
    )
    parser.add_argument(
        "--ckpt-type",
        type=str,
        default="07ffeb8144acdc95aa87f90f65e1457a",
        choices=["07ffeb8144acdc95aa87f90f65e1457a"],
        help="Checkpoint type, which is the md5sum of the model checkpoint file",
    )
    # The only option without a usable default: the protein file is mandatory.
    parser.add_argument(
        "--prot-file",
        type=str,
        required=True,
        help="Protein file for simulation",
    )
    parser.add_argument(
        "--temp-k",
        type=int,
        default=300,
        help="Simulation temperature in Kelvin",
    )
    parser.add_argument(
        "--sim-steps",
        type=int,
        default=100,
        help="Simulation steps for simulation",
    )
    parser.add_argument(
        "--timestep",
        type=float,
        default=1,
        help="TimeStep (fs) for simulation",
    )
    parser.add_argument(
        "--preeq-steps",
        type=int,
        default=100,
        help="Pre-equilibration simulation steps for each constraint",
    )
    parser.add_argument(
        "--max-cyc",
        type=int,
        default=100,
        help="Maximum energy minimization cycles in preprocessing",
    )
    # Paired on/off flags sharing the `constraints` destination.
    parser.add_argument(
        "--constraints",
        action='store_true',
        help="Constrain hydrogen bonds",
    )
    parser.add_argument(
        "--no-constraints",
        action='store_false',
        dest='constraints',
        help="Do not constrain hydrogen bonds",
    )
    parser.add_argument(
        "--mm-method",
        type=str,
        default="tinker-GPU",
        choices=["tinker", "tinker-GPU"],
        help="MM calculator for the nonbonded energy",
    )
    # Paired on/off flags sharing the `solvent` destination.
    parser.add_argument(
        "--solvent",
        action='store_true',
        help="Use solvent",
    )
    parser.add_argument(
        "--no-solvent",
        action='store_false',
        dest='solvent',
        help="Do not use solvent",
    )
    parser.add_argument(
        "--solvent-method",
        type=str,
        default="AMOEBA",
        choices=["AMOEBA"],
        help="Method to use for preprocessing the protein",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for simulation",
    )
    # Paired on/off flags sharing the `restart` destination.
    parser.add_argument(
        "--restart",
        action='store_true',
        help="Restart the simulation",
    )
    parser.add_argument(
        "--no-restart",
        action='store_false',
        dest='restart',
        help="Do not restart the simulation",
    )
    # Paired on/off flags sharing the `build_frames` destination.
    parser.add_argument(
        "--build-frames",
        action='store_true',
        help="Build xyz frames from the trajectory after simulation",
    )
    parser.add_argument(
        "--no-build-frames",
        action='store_false',
        dest='build_frames',
        help="Do not build xyz frames from the trajectory",
    )
    parser.add_argument(
        "--device-strategy",
        type=str,
        default="small-molecule",
        choices=["excess-compute", "small-molecule", "large-molecule"],
        help="""The compute device allocation strategy.
        excess-compute=Assume compute resources are more than sufficient for
                ViSNet inference. Reserves last GPU for solvent/non-bonded
                computation.
        small-molecule=Maximise resources for ViSNet.
        large-molecule=Maximise resources for ViSNet, while also maximising
                concurrency and usage of GPUs for computation.
        """,
    )
    parser.add_argument(
        "--work-strategy",
        type=str,
        default="combined",
        choices=["combined"],
        help="""The work allocation strategy.
        combined=Distribute work evenly amongst both types of fragments.
        """,
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=9999,
        help="""Define the maximum chunk size (in units of atoms) for
        ACE-NME/dipeptide fragments.  The data will be split and processed
        according to these sizes.
        """,
    )
    parser.add_argument(
        "--frag-nonbonded-calc",
        type=str,
        default="mm",
        choices=["mm", "pme"],
        help="Nonbonded calculator for fragments; required when fragmentation is enabled.",
    )
    # Repeatable flag: each occurrence of -v increases the count by one.
    parser.add_argument(
        "-v",
        "--verbose",
        action='count',
        default=0,
        help="""Verbosity level"""
    )

    # additional arguments (launcher-only; main() strips --software-update and
    # --gpus from the forwarded argument list)
    parser.add_argument(
        "--software-update",
        action='store_true',
        help="""When specified, updates the program before running."""
    )
    # NOTE: main() acts on this flag before parse_args() is called; it is
    # registered here so that it appears in --help.
    parser.add_argument(
        "--download-training-data",
        action='store_true',
        help="""When specified, downloads the AI2BMD training data, and unpacks it in the working directory. Ignores all other options."""
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help="""Specifies the GPU devices to passthrough to the program. Can be one of the following:
        all:        Passthrough all available GPUs to the program.
        none:       Disables GPU passthrough.
        i[,j,k...]  Passthrough some GPUs. Example: --gpus 0,1"""
    )
    return parser.parse_args(argv)

def rewrite_argv(argv, key, value):
    """Set the value of option ``key`` in ``argv`` in place.

    Both spellings accepted by argparse are handled: ``key VALUE`` (two
    tokens) and ``key=VALUE`` (one token). Every occurrence is rewritten,
    because argparse honours the last one when an option is repeated. If the
    option is absent, ``key`` and ``value`` are appended.

    Args:
        argv: Mutable list of argument strings; modified in place.
        key: Option name, e.g. ``"--base-dir"``.
        value: Replacement value (a string).
    """
    inline_prefix = key + "="
    found = False
    pos = 0
    while pos < len(argv):
        token = argv[pos]
        if token == key:
            if pos + 1 < len(argv):
                argv[pos + 1] = value
            else:
                # Dangling option with no value: supply one instead of
                # raising IndexError.
                argv.append(value)
            found = True
            pos += 2  # skip the value that was just written
        elif token.startswith(inline_prefix):
            argv[pos] = inline_prefix + value
            found = True
            pos += 1
        else:
            pos += 1
    if not found:
        argv.extend([key, value])

def remove_argv(argv, key, n):
    """Strip option ``key`` from ``argv`` in place.

    Every occurrence is removed. An occurrence is either ``key`` followed by
    its ``n - 1`` value tokens or, for options that take a value (``n > 1``),
    the single-token ``key=VALUE`` spelling. A missing option is not an
    error, and removal never runs past the end of the list.

    Args:
        argv: Mutable list of argument strings; modified in place.
        key: Option name, e.g. ``"--gpus"``.
        n: Number of tokens the option occupies, including ``key`` itself
            (1 for a flag, 2 for an option with one value).
    """
    if n <= 0:
        return
    inline_prefix = key + "="
    pos = 0
    while pos < len(argv):
        token = argv[pos]
        if token == key:
            # Slice deletion clamps at the end of the list, so a dangling
            # option without its value cannot raise IndexError.
            del argv[pos:pos + n]
        elif n > 1 and token.startswith(inline_prefix):
            del argv[pos]
        else:
            pos += 1

def check_env(cmd, message, do_exit=True):
    """Probe the host for a required tool by running a command.

    The command's output is discarded; only whether it could be started and
    its exit status matter.

    Args:
        cmd: Command and arguments as a list, passed to ``subprocess.run``.
        message: Text printed when the probe fails.
        do_exit: When true (the default), terminate the process with
            ``sys.exit(-1)`` on failure instead of returning.

    Returns:
        ``True`` if the command ran and exited with status 0. ``False`` if it
        could not be started or exited non-zero, and ``do_exit`` is false.
    """
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except OSError:
        # Covers a missing executable (FileNotFoundError) as well as one that
        # exists but cannot be executed (e.g. PermissionError).
        succeeded = False
    else:
        succeeded = result.returncode == 0

    if succeeded:
        return True
    print(message)
    if do_exit:
        sys.exit(-1)
    return False

def cat(args):
    """Concatenate the strings in ``args`` into one string, separated by single spaces."""
    separator = ' '
    return separator.join(args)

def run(msg, cmd):
    """Print a progress message, run a shell command, and abort on failure.

    Args:
        msg: Progress text printed before the command starts.
        cmd: Command line, executed through the system shell via ``os.system``.

    Raises:
        SystemExit: If the command fails. The exit code is the command's own
            exit code, or ``128 + signal number`` if it was killed by a signal.
    """
    # Flush so the message is not reordered after the child's output when
    # stdout is a pipe or a file.
    print(msg, flush=True)
    status = os.system(cmd)
    if status == 0:
        return
    print('An error occurred when running the command, aborting.')
    if os.name == "posix":
        # On POSIX os.system returns a wait status, not an exit code. Passing
        # it to exit() unchanged would truncate e.g. 256 (exit code 1) to 0.
        if os.WIFEXITED(status):
            code = os.WEXITSTATUS(status)
        else:
            code = 128 + os.WTERMSIG(status)
    else:
        code = status
    sys.exit(code or 1)


def _download_training_data():
    """Pull the training-data images, unpack their archives into the cwd, and clean up."""
    container_name = "ai2bmd_training_data_copy_container"
    data_images = [
        _docker_image_data1,
        _docker_image_data2,
        _docker_image_data3,
        _docker_image_data4,
        _docker_image_data5,
        _docker_image_data6,
    ]

    run('>>> Downloading AI2BMD training data...',
        " && ".join(f"docker pull {image}" for image in data_images))

    unpack_steps = []
    for part, image in enumerate(data_images, start=1):
        unpack_steps.append(
            # The parentheses confine `|| true` to the best-effort removal of
            # a stale container. Without them it would also swallow a failure
            # of any earlier step in the `&&` chain.
            f"(docker container rm {container_name} 1> /dev/null 2> /dev/null || true) && "
            f"docker create --name {container_name} {image} bash && "
            f"docker cp {container_name}:/ai2bmd-training-part-{part}.tar.xz - | tar x -O | tar xJv"
        )
    run('>>> Unpacking data...', " && ".join(unpack_steps))

    run('>>> Cleaning up...',
        " && ".join([f"docker container rm {container_name}"]
                    + [f"docker image rm {image}" for image in data_images]))
    print('>>> Done.')


def _mount_options(paths):
    """Return `docker run -v` options mapping each host path to /mnt<path>, skipping nested paths."""
    kept = []
    options = []
    # Sorting puts a directory before anything inside it, so one pass suffices.
    for host_path in sorted(set(paths)):
        covered = any(
            # Compare whole path components: "/data2" is not inside "/data".
            host_path == parent
            or host_path.startswith(parent.rstrip(os.sep) + os.sep)
            for parent in kept
        )
        if not covered:
            kept.append(host_path)
            options.extend(['-v', f'{host_path}:/mnt{host_path}'])
    return options


def _exit_code_from_status(status):
    """Convert an `os.system` result into a process exit code."""
    if os.name != "posix":
        return status
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    return 128 + os.WTERMSIG(status)


def main() -> None:
    """Entry point: run AI2BMD inside its Docker image.

    With ``--download-training-data`` anywhere on the command line, only the
    training data is downloaded and unpacked, and the process exits.

    Otherwise the function:
      0. checks that Docker (and, unless ``--gpus none``, the NVIDIA driver
         and container toolkit) are usable;
      1. optionally pulls the latest image (``--software-update``);
      2. bind-mounts the cwd, base dir, log dir and protein-file directory
         under ``/mnt`` and rewrites the forwarded paths to match;
      3. chooses the ``--gpus`` option for ``docker run``;
      4. launches ``/ai2bmd/main.py`` in the container with the remaining
         arguments.

    Raises:
        SystemExit: On a failed health check, after the training-data
            download, or with the container's exit code when the launched
            program fails.
    """
    # Local import: only this function needs it.
    import shlex

    argv = sys.argv[1:]

    # download training data (handled before parsing: ignores all other options)
    if "--download-training-data" in argv:
        _download_training_data()
        sys.exit(0)

    args = parse_args(argv)

    # 0. health check
    check_env(['docker', 'ps'],
              "Docker is not installed or not running. "
              "Please see https://docs.docker.com/engine/install/ for installation instructions.")
    if args.gpus == "none":
        # GPU passthrough is explicitly disabled, so GPU tooling is not required.
        gpu_ok = False
    else:
        gpu_ok = check_env(['nvidia-smi'],
                           "NVIDIA driver is not installed or not properly loaded. "
                           "Please see https://docs.nvidia.com/cuda/cuda-installation-guide-linux/ for installation instructions. "
                           "For mainstream GNU/Linux distributions, please refer to section 3, Package Manager Installation.",
                           False)
        # NOTE(review): unlike the driver check, a missing container toolkit
        # aborts the launcher (do_exit defaults to True) — confirm this is
        # intended rather than falling back to a CPU-only run.
        gpu_ok = gpu_ok and check_env(['nvidia-container-cli', '--version'],
                                      "NVIDIA container toolkit is not installed or not configured properly. "
                                      "Please see https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html for installation instructions.")

    # 1. check for update
    remove_argv(argv, "--software-update", 1)
    if args.software_update:
        if os.system(f"docker pull {_docker_image}") != 0:
            print("Warning: failed to update the Docker image; continuing with the existing image, if any.")
        # TODO update launcher

    # 2. get mount paths (all mapped to /mnt/abs-path...), rewrite argv
    # 2a. cwd
    cwd = osp.abspath(os.getcwd())
    mount_paths = [cwd]
    docker_cwd = "/mnt" + cwd

    # 2b. base dir
    base_dir = osp.abspath(args.base_dir)
    mount_paths.append(base_dir)
    rewrite_argv(argv, "--base-dir", "/mnt" + base_dir)

    # 2c. log dir
    if args.log_dir is not None:
        log_dir = osp.abspath(args.log_dir)
        mount_paths.append(log_dir)
        rewrite_argv(argv, "--log-dir", "/mnt" + log_dir)

    # 2d. prot file dir (the file's directory is mounted, not the file itself)
    prot_file = osp.abspath(args.prot_file)
    mount_paths.append(osp.dirname(prot_file))
    rewrite_argv(argv, "--prot-file", "/mnt" + prot_file)

    mount_opts = _mount_options(mount_paths)

    # 3. GPU option
    remove_argv(argv, "--gpus", 2)
    if not gpu_ok or args.gpus == "none":
        gpu_opts = []
    elif args.gpus is None or args.gpus == "all":
        gpu_opts = ["--gpus", "all"]
    else:
        # Docker expects the device list wrapped in literal double quotes.
        gpu_opts = ["--gpus", f'"device={args.gpus}"']

    # 4. Launch the program
    # TODO --user {uid}:{gid}
    # ^ To run as a specific user is harder than just simply setting the uid:gid -- a new Docker image must be built upon
    # the default, so that uid:gid can be dynamically inserted into the image with `useradd`, otherwise some libraries
    # will complain about non-existing user, and non-writable cache paths.
    launch_parts = [
        "docker", "run", "-w", docker_cwd, *gpu_opts, "--rm", *mount_opts,
        _docker_image, "python", "-u", "/ai2bmd/main.py", *argv,
    ]
    # Quote every token so paths containing spaces or shell metacharacters
    # reach docker intact.
    launch_cmd = " ".join(shlex.quote(part) for part in launch_parts)
    print(f"Launching AI2BMD with Docker: {launch_cmd}", flush=True)
    status = os.system(launch_cmd)
    if status != 0:
        # Propagate the container's failure instead of always exiting with 0.
        sys.exit(_exit_code_from_status(status) or 1)


# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()